---
layout: post
date: 2024-10-11
title: "TCP Server in Zig - Part 4 - Multithreading"
description: "Upgrading our simple server to use multiple worker threads"
tags: [zig]
---

<p>We finished Part 1 with a simple single-threaded server, which we could describe as:</p>
<ol>
  <li>Create our socket
  <li>Bind it to an address
  <li>Put it in "server" mode (i.e. call <code>listen</code> on it)
  <li>Accept a connection
  <li>Application logic involving reading/writing to the socket
  <li>Close the connection
  <li>Goto step 4
</ol>

<p>While this approach is useful for getting familiar with various socket APIs and networking concepts, it isn't practical for most real-world applications. The issue is that it can only service one client at a time. Any additional clients will block (or fail to connect) until the currently connected client is finished and the server calls <code>accept</code> again to process the next client in line.</p>

<p>There are a few different ways to deal with this problem, some of which are complementary. But a common place to start is to move steps 5 and 6 from our above list into their own thread:</p>

{% highlight zig %}
const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    // All of the same existing code from before, to create, bind and listen on
    // the socket. A complete working example is given at the end of this post.

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            // Rare that this happens, but in later parts we'll
            // see examples where it does.
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };
        const thread = try std.Thread.spawn(.{}, run, .{socket, client_address});
        thread.detach();
    }
}

fn run(socket: posix.socket_t, address: std.net.Address) !void {
    defer posix.close(socket);

    const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

    var buf: [1024]u8 = undefined;
    var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

    while (true) {
        const msg = try reader.readMessage();
        std.debug.print("Got: {s}\n", .{msg});
    }
}
{% endhighlight %}

<p><code>Thread.spawn</code> is used to launch a new thread. The first parameter is a <code>SpawnConfig</code> which allows us to set a stack size (defaults to 16MB). The second parameter is the function to run and the last parameter is a tuple of the arguments to pass to that function. In our case, the first argument is the <code>socket</code> and the second is the <code>client_address</code>. As a small tweak, we can create a <code>Client</code> to start encompassing our client-handling logic:</p>

{% highlight zig %}
const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) !void {
        const socket = self.socket;

        defer posix.close(socket);
        std.debug.print("{} connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("Got: {s}\n", .{msg});
        }
    }

};
{% endhighlight %}

<p>This is the same code as above, just organized a little better. To call this from our main listening thread, the change is small:</p>

{% highlight zig %}
while (true) {
    var client_address: net.Address = undefined;
    var client_address_len: posix.socklen_t = @sizeOf(net.Address);
    const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
        //  same code as before
    };

    const client = Client{ .socket = socket, .address = client_address };
    const thread = try std.Thread.spawn(.{}, Client.handle, .{client});
    thread.detach();
}
{% endhighlight %}

<p>The <code>client</code> variable in the above code is scoped to, and only valid within, the <code>while</code> block. Care must be taken with respect to the values we pass to <code>Thread.spawn</code>; the spawned thread is independent from the location where it was spawned from. If we changed the above to pass a reference to client, <code>&client</code>, we'd be in undefined behavior territory. Our code would probably crash, but before that happens, it could send the wrong message to the wrong client.</p>

<p><code>spawn</code> returns a <code>std.Thread</code>. You always want to call either <code>join</code> or <code>detach</code> on this. <code>join</code> will block the caller until the thread exits. <code>detach</code> signals that we never intend to call <code>join</code> on the thread. That might seem odd, but it lets the threading implementation (which is platform specific) release any data associated with the thread as soon as the thread exits. If you don't call <code>detach</code>, some state has to stick around because you might call <code>join</code> at some point in the future.</p>

<p>If we replaced our call to <code>detach</code> with <code>join</code>, our implementation would behave a lot like our initial single-threaded example. Our main thread would accept the connection and spawn a new thread, but then would block on the call to <code>join</code> until the newly spawned thread terminated.</p>
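<p>To see <code>join</code> in a case where it is what we want, here's a minimal sketch, separate from our server. The <code>work</code> and <code>runWorkers</code> functions, the shared counter and the 256KB <code>stack_size</code> are all made up for illustration. Because every thread is joined before <code>runWorkers</code> returns, passing a pointer to the local <code>counter</code> is safe here, unlike the <code>&client</code> case described above:</p>

{% highlight zig %}
const std = @import("std");

// Hypothetical unit of work: bump a shared counter.
fn work(counter: *std.atomic.Value(u32)) void {
    _ = counter.fetchAdd(1, .monotonic);
}

// Spawns 4 threads, waits for all of them, then returns the counter.
fn runWorkers() !u32 {
    var counter = std.atomic.Value(u32).init(0);

    var threads: [4]std.Thread = undefined;
    var spawned: usize = 0;

    // If a later spawn fails, still join the threads that did start,
    // since they hold a pointer to our local counter.
    errdefer {
        for (threads[0..spawned]) |t| t.join();
    }

    while (spawned < threads.len) : (spawned += 1) {
        // stack_size is in bytes; 256KB is an arbitrary value for this sketch.
        threads[spawned] = try std.Thread.spawn(.{ .stack_size = 256 * 1024 }, work, .{&counter});
    }

    // join blocks until each thread has exited, so counter is still
    // valid for as long as any thread can touch it.
    for (threads) |t| t.join();
    return counter.load(.monotonic);
}

pub fn main() !void {
    std.debug.print("{d} threads finished\n", .{try runWorkers()});
}
{% endhighlight %}

<p>With <code>detach</code> instead of <code>join</code>, nothing would stop <code>runWorkers</code> from returning while a thread still held the pointer. That is why our server passes <code>client</code> by value.</p>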

<h3 id=threadpool><a href="#threadpool" aria-hidden=true>ThreadPool</a></h3>
<p>Spawning a thread per connection is a simple modification to our existing code, but our current implementation will spawn as many threads as there are connections. That could be an issue if we're expecting thousands of concurrent connections. There's no one-size-fits-all rule for the number of threads a system can support. It'll depend on the hardware we're running on and, critically, what those threads are doing. If the threads are doing heavy CPU work, having more threads than CPU cores could hurt performance.</p>

<p>One common solution to this problem is to use a ThreadPool. Let's modify our code, then look at how it works:</p>

{% highlight zig %}
pub fn main() !void {
    // our thread pool needs an allocator
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{.allocator = allocator, .n_jobs = 64});

    // ...
    // all the same socket setup stuff as before
    // removed to keep this snippet more readable..
    // ...

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            //  same code as before
        };

        const client = Client{ .socket = socket, .address = client_address };
        try pool.spawn(Client.handle, .{client});
    }
}
{% endhighlight %}

<p>We also need to make a slight change to our <code>Client.handle</code>, but let's review the above first. For the first time in this series, we need an allocator. You might find it odd that we define <code>pool</code> and then pass its address into <code>Thread.Pool.init</code>. This is because <code>init</code> needs a stable pointer to the <code>Thread.Pool</code> instance, and this provides the caller with the flexibility to decide how that should be achieved (the standard library doesn't like calling <code>allocator.create(Self)</code> within <code>init</code> functions).</p>
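<p>The "initialize in place" pattern can be sketched with a small type of our own. The <code>Counter</code> below is hypothetical and much simpler than a thread pool. Presumably the pool needs this because its worker threads hold on to a pointer to it, but that part is an assumption. What the sketch shows is that the address recorded during <code>init</code> is the same address the caller keeps using:</p>

{% highlight zig %}
const std = @import("std");

// Hypothetical type whose init records a pointer to the instance itself.
const Counter = struct {
    value: u32,
    self_ptr: *Counter,

    // Initializes in place. Had init returned a Counter by value, self_ptr
    // would point at a temporary rather than at the caller's variable.
    fn init(self: *Counter, value: u32) void {
        self.* = .{ .value = value, .self_ptr = self };
    }
};

pub fn main() void {
    // The caller decides where the instance lives: here, on the stack.
    var counter: Counter = undefined;
    Counter.init(&counter, 1);

    // Writing through the stored pointer changes the caller's instance.
    counter.self_ptr.value += 1;
    std.debug.print("value: {d}\n", .{counter.value}); // value: 2
}
{% endhighlight %}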

<p>Previously, our <code>Client.handle</code> returned <code>!void</code>, which <code>Thread.spawn</code> accepts. <code>Thread.Pool</code> doesn't: the function it runs must return <code>void</code>, so we need to change <code>handle</code> to return <code>void</code> instead of <code>!void</code>. Since I like the ergonomics of <code>try</code>, I think the best option is to make <code>handle</code> a wrapper around a function that returns an error:</p>

{% highlight zig %}
const Client = struct {
    // ...

    fn handle(self: Client) void {
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            else => std.debug.print("[{any}] client handle error: {}\n", .{self.address, err}),
        };
    }

    fn _handle(self: Client) !void {
        // same handle as before
    }
};
{% endhighlight %}
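<p>The same wrapper pattern can be tried outside of our server. This is only a sketch: <code>Results</code>, <code>job</code>, <code>tryJob</code>, <code>runJobs</code> and <code>error.OddJob</code> are made-up names, and the <code>std.Thread.WaitGroup</code> is there just so the example can wait for the jobs to finish before exiting:</p>

{% highlight zig %}
const std = @import("std");

// Hypothetical shared state for this sketch.
const Results = struct {
    succeeded: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
    wg: std.Thread.WaitGroup = .{},
};

// The function handed to the pool returns void and deals with errors itself.
fn job(id: u32, results: *Results) void {
    defer results.wg.finish();
    tryJob(id) catch |err| {
        std.debug.print("job {d} failed: {s}\n", .{ id, @errorName(err) });
        return;
    };
    _ = results.succeeded.fetchAdd(1, .monotonic);
}

// The real work can still use try. Odd ids fail, purely for demonstration.
fn tryJob(id: u32) !void {
    if (id % 2 == 1) return error.OddJob;
}

// Runs 8 jobs on a 4-worker pool and returns how many succeeded.
fn runJobs(allocator: std.mem.Allocator) !u32 {
    var results = Results{};

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = 4 });
    // deinit joins the workers, so results outlives every job.
    defer pool.deinit();

    var id: u32 = 0;
    while (id < 8) : (id += 1) {
        results.wg.start();
        pool.spawn(job, .{ id, &results }) catch |err| {
            results.wg.finish(); // this job will never run
            return err;
        };
    }

    results.wg.wait();
    return results.succeeded.load(.monotonic);
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    const succeeded = try runJobs(gpa.allocator());
    std.debug.print("{d} of 8 jobs succeeded\n", .{succeeded}); // 4 of 8 jobs succeeded
}
{% endhighlight %}

<p>The order of the "failed" lines will vary from run to run, since the jobs execute on different workers.</p>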

<p>One issue with Zig's <code>Thread.Pool</code> is that it can be memory-intensive. This is because it's generic: each invocation can be given a different function to run and different parameters to use. This requires creating a <a href="/Closures-in-Zig/">closure</a> around the arguments. My understanding is that Zig's <code>Thread.Pool</code> was designed for long-running jobs, where the initial overhead is a relatively small cost compared to the overall work being done. If you're running many short-lived jobs, like processing HTTP requests, you might want to look at <a href="/Writing-a-Task-Scheduler-in-Zig/">something more optimized for that use-case</a>.</p>

<p>Finally, in the above code, we initialized our <code>Thread.Pool</code> with <code>n_jobs</code> set to 64. This is the number of workers our pool will have and represents the maximum number of concurrent jobs it will be able to run. This limits the number of connections our system can service at once to 64. Further connections are still accepted by our main loop, but their jobs sit in the pool's queue until a worker becomes free. On some hardware, for some workloads, you could set this number much higher.</p>
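<p>To make that limit concrete, here's a rough sketch that queues 6 jobs on a pool with <code>n_jobs</code> set to 2 and records how many were running at the same time. <code>Stats</code>, <code>job</code> and <code>peakConcurrency</code> are made-up names, and the 20ms sleep is an arbitrary stand-in for a handler waiting on a client:</p>

{% highlight zig %}
const std = @import("std");
const posix = std.posix;

// Hypothetical bookkeeping for this sketch.
const Stats = struct {
    active: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
    peak: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
    wg: std.Thread.WaitGroup = .{},
};

fn job(stats: *Stats) void {
    defer stats.wg.finish();

    // Record how many jobs are running right now, and the highest value seen.
    const running = stats.active.fetchAdd(1, .acq_rel) + 1;
    _ = stats.peak.fetchMax(running, .acq_rel);

    // Stand-in for a handler blocked on a slow client.
    posix.nanosleep(0, 20 * std.time.ns_per_ms);

    _ = stats.active.fetchSub(1, .acq_rel);
}

// Queues 6 jobs on a 2-worker pool and returns the peak number running at once.
fn peakConcurrency(allocator: std.mem.Allocator) !u32 {
    var stats = Stats{};

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = 2 });
    defer pool.deinit();

    for (0..6) |_| {
        stats.wg.start();
        pool.spawn(job, .{&stats}) catch |err| {
            stats.wg.finish(); // this job will never run
            return err;
        };
    }

    // Only the pool's workers run jobs; this thread just waits.
    stats.wg.wait();
    return stats.peak.load(.acquire);
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    const peak = try peakConcurrency(gpa.allocator());
    std.debug.print("peak concurrency: {d} (n_jobs = 2)\n", .{peak});
}
{% endhighlight %}

<p>The reported peak should never exceed 2, even though all 6 jobs were submitted up front. The other jobs wait in the queue, just like the 65th client would in our server.</p>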

<h3 id=conclusion><a href="#conclusion" aria-hidden=true>Conclusion</a></h3>
<p>Our 64-worker ThreadPool is considerably better than our single-threaded implementation without adding much complexity. The most significant change we made, introducing the <code>Client</code>, was housecleaning unrelated to our multithreaded initiative. You can forgo the ThreadPool and call <code>Thread.spawn</code> directly, but you should only do this in more controlled environments - where the number of clients is known (and relatively small) and you aren't exposing your service to the public internet.</p>

<p>In the next part, we'll take our first look at a nonblocking implementation, which is how we'll be able to support a much larger number of connections. However, as we'll see, the nonblocking implementation doesn't have to be an alternative to using a ThreadPool. The two can complement each other, so everything we've learnt here will continue to be useful.</p>

<div class=pager>
  <a class=prev href=/TCP-Server-In-Zig-Part-3-Minimizing-Writes-and-Reads/>Minimizing Writes & Reads</a>
  <a class=next href=/TCP-Server-In-Zig-Part-5a-Poll/>Poll (Part 1)</a>
</div>

<h3 id=appendix-a><a href="#appendix-a" aria-hidden=true>Appendix A - Code</a></h3>
<p>Here's a complete working server implementation:</p>

{% highlight zig %}
// Multithreaded server with a ThreadPool
const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{.allocator = allocator, .n_jobs = 64});

    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const listener = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(listener);

    try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(listener, &address.any, address.getOsSockLen());
    try posix.listen(listener, 128);

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        const client = Client{ .socket = socket, .address = client_address };
        try pool.spawn(Client.handle, .{client});
    }
}

const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) void {
        defer posix.close(self.socket);
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            error.WouldBlock => {}, // read or write timeout
            else => std.debug.print("[{any}] client handle error: {}\n", .{self.address, err}),
        };
    }

    fn _handle(self: Client) !void {
        const socket = self.socket;
        std.debug.print("[{}] connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("[{}] sent: {s}\n", .{self.address, msg});
        }
    }
};

const Reader = struct {
    buf: []u8,
    pos: usize = 0,
    start: usize = 0,
    socket: posix.socket_t,

    fn readMessage(self: *Reader) ![]u8 {
        var buf = self.buf;

        while (true) {
            if (try self.bufferedMessage()) |msg| {
                return msg;
            }
            const pos = self.pos;
            const n = try posix.read(self.socket, buf[pos..]);
            if (n == 0) {
                return error.Closed;
            }
            self.pos = pos + n;
        }
    }

    fn bufferedMessage(self: *Reader) !?[]u8 {
        const buf = self.buf;
        const pos = self.pos;
        const start = self.start;

        std.debug.assert(pos >= start);
        const unprocessed = buf[start..pos];
        if (unprocessed.len < 4) {
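            // ensureSpace measures from `start`, so ask for room for the whole
            // 4-byte prefix, not just the bytes still missing.
            self.ensureSpace(4) catch unreachable;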
            return null;
        }

        const message_len = std.mem.readInt(u32, unprocessed[0..4], .little);

        // the length of our message + the length of our prefix
        // (widened to usize so a huge length from the client can't overflow the u32)
        const total_len = @as(usize, message_len) + 4;

        if (unprocessed.len < total_len) {
            try self.ensureSpace(total_len);
            return null;
        }

        self.start += total_len;
        return unprocessed[4..total_len];
    }

    fn ensureSpace(self: *Reader, space: usize) error{BufferTooSmall}!void {
        const buf = self.buf;
        if (buf.len < space) {
            return error.BufferTooSmall;
        }

        const start = self.start;
        const spare = buf.len - start;
        if (spare >= space) {
            return;
        }

        const unprocessed = buf[start..self.pos];
        std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
        self.start = 0;
        self.pos = unprocessed.len;
    }
};
{% endhighlight %}

<p>Which you can test with this dummy client:</p>

{% highlight zig %}
// Test client
const std = @import("std");
const posix = std.posix;

pub fn main() !void {
    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const socket = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(socket);

    try posix.connect(socket, &address.any, address.getOsSockLen());
    try writeMessage(socket, "Hello World");
    try writeMessage(socket, "It's Over 9000!!");
}

fn writeMessage(socket: posix.socket_t, msg: []const u8) !void {
    var buf: [4]u8 = undefined;
    std.mem.writeInt(u32, &buf, @intCast(msg.len), .little);

    var vec = [2]posix.iovec_const{
      .{ .len = 4, .base = &buf },
      .{ .len = msg.len, .base = msg.ptr },
    };
    try writeAllVectored(socket, &vec);
}

fn writeAllVectored(socket: posix.socket_t, vec: []posix.iovec_const) !void {
    var i: usize = 0;
    while (true) {
        var n = try posix.writev(socket, vec[i..]);
        while (n >= vec[i].len) {
            n -= vec[i].len;
            i += 1;
            if (i >= vec.len) return;
        }
        vec[i].base += n;
        vec[i].len -= n;
    }
}
{% endhighlight %}

